{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hfnet.datasets.hpatches import Hpatches\n",
    "from hfnet.evaluation.loaders import sift_loader, export_loader, fast_loader, harris_loader\n",
    "from hfnet.evaluation.keypoint_detectors import evaluate\n",
    "from hfnet.utils import tools\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_config = {'make_pairs': True, 'shuffle': True}\n",
    "dataset = Hpatches(**data_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_configs = {\n",
    "    'sift': {\n",
    "        'predictor': sift_loader,\n",
    "        'do_nms': False,\n",
    "        'nms_thresh': 8,\n",
    "    },\n",
    "    'fast': {\n",
    "        'predictor': fast_loader,\n",
    "        'do_nms': True,\n",
    "        'nms_thresh': 8,\n",
    "    },\n",
    "    'harris': {\n",
    "        'predictor': harris_loader,\n",
    "        'do_nms': True,\n",
    "        'nms_thresh': 8,\n",
    "    },\n",
    "    'superpoint': {\n",
    "        'experiment': 'super_point_pytorch/hpatches',\n",
    "        'predictor': export_loader,\n",
    "        'do_nms': True,\n",
    "        'nms_thresh': 8,\n",
    "        'remove_borders': 4,\n",
    "    },\n",
    "    'lfnet': {\n",
    "        'experiment': 'lfnet/hpatches_kpts-500',\n",
    "        'predictor': export_loader,\n",
    "    },\n",
    "}\n",
    "eval_config = {'correct_match_thresh': 3, 'num_features': 500}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "methods = ['sift', 'harris', 'fast', 'superpoint', 'lfnet']\n",
    "configs = {m: all_configs[m] for m in methods}\n",
    "for method, config in configs.items():\n",
    "    config = tools.dict_update(config, eval_config)\n",
    "    data_iter = dataset.get_test_set()\n",
    "    metrics, _, _, _ = evaluate(data_iter, config, is_2d=True)\n",
    "    \n",
    "    print('> {}'.format(method))\n",
    "    for k, v in metrics.items():\n",
    "        print('{:<25} {:.3f}'.format(k, v))\n",
    "    print(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [02:53,  4.64it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> sift\n",
      "average_num_keypoints     499.885\n",
      "localization_error        0.968\n",
      "repeatability             0.327\n",
      "mAP                       0.118\n",
      "{'predictor': <function sift_loader at 0x2b8d0d668bf8>, 'do_nms': False, 'nms_thresh': 8, 'correct_match_thresh': 3, 'num_features': 500}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [03:03,  3.94it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> harris\n",
      "average_num_keypoints     477.905\n",
      "localization_error        1.215\n",
      "repeatability             0.493\n",
      "mAP                       0.285\n",
      "{'predictor': <function harris_loader at 0x2b8d0d67b840>, 'do_nms': True, 'nms_thresh': 8, 'correct_match_thresh': 3, 'num_features': 500}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [01:22,  7.00it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> fast\n",
      "average_num_keypoints     484.231\n",
      "localization_error        1.151\n",
      "repeatability             0.402\n",
      "mAP                       0.197\n",
      "{'predictor': <function fast_loader at 0x2b8d0d67b7b8>, 'do_nms': True, 'nms_thresh': 8, 'correct_match_thresh': 3, 'num_features': 500}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [01:38, 13.48it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> superpoint\n",
      "average_num_keypoints     464.888\n",
      "localization_error        1.127\n",
      "repeatability             0.537\n",
      "mAP                       0.342\n",
      "{'experiment': 'super_point_pytorch/hpatches', 'predictor': <function export_loader at 0x2b8d0d67b8c8>, 'do_nms': True, 'nms_thresh': 8, 'remove_borders': 4, 'correct_match_thresh': 3, 'num_features': 500}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [01:01,  9.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> lfnet\n",
      "average_num_keypoints     500.000\n",
      "localization_error        1.184\n",
      "repeatability             0.484\n",
      "mAP                       0.305\n",
      "{'experiment': 'lfnet/hpatches_kpts-500', 'predictor': <function export_loader at 0x2b8d0d67b8c8>, 'correct_match_thresh': 3, 'num_features': 500}\n"
     ]
    }
   ],
   "source": [
    "# NMS=8, N=500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [03:09,  3.93it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> sift\n",
      "average_num_keypoints     300.000\n",
      "localization_error        0.944\n",
      "repeatability             0.307\n",
      "mAP                       0.102\n",
      "{'predictor': <function sift_loader at 0x2b8d0d668bf8>, 'do_nms': False, 'nms_thresh': 8, 'correct_match_thresh': 3, 'num_features': 300}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [02:58,  4.16it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> harris\n",
      "average_num_keypoints     298.979\n",
      "localization_error        1.140\n",
      "repeatability             0.535\n",
      "mAP                       0.334\n",
      "{'predictor': <function harris_loader at 0x2b8d0d67b840>, 'do_nms': True, 'nms_thresh': 4, 'correct_match_thresh': 3, 'num_features': 300}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [01:21,  7.14it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> fast\n",
      "average_num_keypoints     299.851\n",
      "localization_error        1.095\n",
      "repeatability             0.468\n",
      "mAP                       0.260\n",
      "{'predictor': <function fast_loader at 0x2b8d0d67b7b8>, 'do_nms': True, 'nms_thresh': 4, 'correct_match_thresh': 3, 'num_features': 300}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [02:33,  4.89it/s]\n",
      "0it [00:00, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> superpoint\n",
      "average_num_keypoints     298.055\n",
      "localization_error        1.036\n",
      "repeatability             0.495\n",
      "mAP                       0.276\n",
      "{'experiment': 'super_point_pytorch/hpatches', 'predictor': <function export_loader at 0x2b8d0d67b8c8>, 'do_nms': True, 'nms_thresh': 4, 'remove_borders': 4, 'correct_match_thresh': 3, 'num_features': 300}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "580it [01:27,  6.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "> lfnet\n",
      "average_num_keypoints     300.000\n",
      "localization_error        1.127\n",
      "repeatability             0.460\n",
      "mAP                       0.273\n",
      "{'experiment': 'lfnet/hpatches_kpts-500', 'predictor': <function export_loader at 0x2b8d0d67b8c8>, 'correct_match_thresh': 3, 'num_features': 300}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# NMS=4, N=300"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
